﻿using common.libs;
using common.libs.extends;
using common.server;
using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;

namespace common.socks5
{
    /// <summary>
    /// Server side of the SOCKS5 relay: answers the handshake according to <see cref="Config"/>,
    /// opens TCP connections to requested targets and pumps data from the target socket back to
    /// the client connection through <see cref="Socks5MessengerSender"/>.
    /// </summary>
    public class Socks5ServerHandler : ISocks5ServerHandler
    {
        /// <summary>Open target connections, keyed by (client connection id, request id).</summary>
        private readonly ConcurrentDictionary<ConnectionKey, AsyncServerUserToken> connections = new(new ConnectionKeyComparer());

        private readonly Socks5MessengerSender socks5MessengerSender;
        private readonly Config config;

        public Socks5ServerHandler(Socks5MessengerSender socks5MessengerSender, Config config)
        {
            this.socks5MessengerSender = socks5MessengerSender;
            this.config = config;
        }

        /// <summary>
        /// Chooses the authentication method: <see cref="Socks5EnumAuthType.NoAuth"/> when
        /// connecting is enabled, otherwise <see cref="Socks5EnumAuthType.NotSupported"/>.
        /// </summary>
        public Socks5EnumAuthType HandleRequest(Socks5Info data)
        {
            return config.ConnectEnable ? Socks5EnumAuthType.NoAuth : Socks5EnumAuthType.NotSupported;
        }

        /// <summary>
        /// Reports <see cref="Socks5EnumAuthState.Success"/> when connecting is enabled,
        /// otherwise <see cref="Socks5EnumAuthState.UnKnow"/>.
        /// </summary>
        public Socks5EnumAuthState HandleAuth(Socks5Info data)
        {
            return config.ConnectEnable ? Socks5EnumAuthState.Success : Socks5EnumAuthState.UnKnow;
        }

        /// <summary>
        /// Writes client data to the matching target socket. An empty payload closes the target
        /// connection instead. Requests with no registered connection are ignored.
        /// </summary>
        public void HndleForward(IConnection connection, Socks5Info data)
        {
            if (!config.ConnectEnable)
            {
                return;
            }
            ConnectionKey key = new ConnectionKey(connection.ConnectId, data.Id);
            if (!connections.TryGetValue(key, out AsyncServerUserToken token))
            {
                return;
            }
            if (data.Data.Length > 0)
            {
                token.TargetSocket.Send(data.Data.Span);
            }
            else
            {
                CloseClientSocket(token);
            }
        }

        /// <summary>
        /// Executes a SOCKS5 request. Only CONNECT is supported; BIND, UDP ASSOCIATE, unknown
        /// commands and payloads too short to carry a command byte get
        /// <see cref="Socks5EnumResponseCommand.CommandNotAllow"/>.
        /// </summary>
        /// <param name="connection">Client connection that target data is relayed back to.</param>
        /// <param name="data">Request whose payload holds the SOCKS5 request bytes (command at index 1).</param>
        /// <returns>The reply code to send to the client.</returns>
        public Socks5EnumResponseCommand HandleCommand(IConnection connection, Socks5Info data)
        {
            if (!config.ConnectEnable)
            {
                return Socks5EnumResponseCommand.Unknow;
            }
            // Untrusted input: reject instead of throwing IndexOutOfRangeException below.
            if (data.Data.Length < 2)
            {
                return Socks5EnumResponseCommand.CommandNotAllow;
            }

            Socks5EnumRequestCommand command = (Socks5EnumRequestCommand)data.Data.Span[1];
            if (command != Socks5EnumRequestCommand.Connect)
            {
                // BIND and UDP ASSOCIATE are not implemented.
                return Socks5EnumResponseCommand.CommandNotAllow;
            }

            IPEndPoint remoteEndPoint = Socks5Parser.GetRemoteEndPoint(data.Data.Span);
            if (!config.LanConnectEnable && NetworkHelper.IsLan(remoteEndPoint))
            {
                return Socks5EnumResponseCommand.NetworkError;
            }

            var (responseCommand, socket) = Connect(remoteEndPoint);
            if (responseCommand == Socks5EnumResponseCommand.ConnecSuccess && socket != null)
            {
                BindTargetReceive(connection, data.Id, socket);
            }
            return responseCommand;
        }

        /// <summary>
        /// Opens a blocking TCP connection (keep-alive enabled) to <paramref name="remoteEndPoint"/>.
        /// </summary>
        /// <returns>
        /// (<see cref="Socks5EnumResponseCommand.ConnecSuccess"/>, connected socket) on success;
        /// otherwise a failure code and <c>null</c>. The socket is disposed on failure.
        /// </returns>
        private (Socks5EnumResponseCommand, Socket) Connect(IPEndPoint remoteEndPoint)
        {
            Socket socket = null;
            try
            {
                // Construction is inside the try so AddressFamilyNotSupported maps to a reply code.
                socket = new Socket(remoteEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true);
                socket.Connect(remoteEndPoint);
                return (Socks5EnumResponseCommand.ConnecSuccess, socket);
            }
            catch (SocketException ex)
            {
                socket?.Dispose();
                return (MapConnectError(ex.SocketErrorCode), null);
            }
            catch (Exception ex)
            {
                socket?.Dispose();
                Logger.Instance.DebugError(ex);
                return (Socks5EnumResponseCommand.ServerError, null);
            }
        }

        /// <summary>
        /// Translates a connect failure into a SOCKS5 reply code. Unmapped errors become
        /// <see cref="Socks5EnumResponseCommand.ServerError"/> so a failure is never reported as success.
        /// </summary>
        private static Socks5EnumResponseCommand MapConnectError(SocketError error) => error switch
        {
            SocketError.ConnectionRefused or SocketError.ConnectionReset => Socks5EnumResponseCommand.DistReject,
            SocketError.NetworkDown or SocketError.NetworkUnreachable => Socks5EnumResponseCommand.NetworkError,
            SocketError.AddressFamilyNotSupported or SocketError.OperationNotSupported => Socks5EnumResponseCommand.AddressNotAllow,
            _ => Socks5EnumResponseCommand.ServerError,
        };

        /// <summary>Completion callback for asynchronous receives on target sockets.</summary>
        private void Target_IO_Completed(object sender, SocketAsyncEventArgs e)
        {
            switch (e.LastOperation)
            {
                case SocketAsyncOperation.Receive:
                    TargetProcessReceive(e);
                    break;
                default:
                    break;
            }
        }

        /// <summary>
        /// Registers the relay for (<paramref name="Connection"/>, <paramref name="id"/>) and starts
        /// receiving from <paramref name="targetSocket"/>. On failure the relay is torn down and
        /// ResponseClose is sent for the request.
        /// </summary>
        private void BindTargetReceive(IConnection Connection, ulong id, Socket targetSocket)
        {
            AsyncServerUserToken token = new AsyncServerUserToken
            {
                Connection = Connection,
                TargetSocket = targetSocket,
                Key = new ConnectionKey(Connection.ConnectId, id)
            };
            SocketAsyncEventArgs readEventArgs = null;
            try
            {
                connections.TryAdd(token.Key, token);
                readEventArgs = new SocketAsyncEventArgs
                {
                    UserToken = token,
                    SocketFlags = SocketFlags.None,
                };
                readEventArgs.SetBuffer(new byte[config.BufferSize], 0, config.BufferSize);
                readEventArgs.Completed += Target_IO_Completed;
                if (!token.TargetSocket.ReceiveAsync(readEventArgs))
                {
                    TargetProcessReceive(readEventArgs);
                }
            }
            catch (Exception ex)
            {
                Logger.Instance.DebugError(ex);
                // Close the target socket and unregister the token so a failed bind leaks nothing.
                CloseClientSocket(token);
                readEventArgs?.Dispose();
                _ = socks5MessengerSender.ResponseClose(token.Key.RequestId, Connection);
            }
        }

        /// <summary>
        /// Handles completed receives from a target socket: forwards each chunk to the client,
        /// re-arms the receive, and closes the relay on EOF, socket error, disconnect or exception.
        /// </summary>
        private void TargetProcessReceive(SocketAsyncEventArgs e)
        {
            AsyncServerUserToken token = (AsyncServerUserToken)e.UserToken;
            try
            {
                // Loop rather than recurse when ReceiveAsync completes synchronously, so a target
                // that keeps data ready cannot grow the stack without bound.
                while (e.BytesTransferred > 0 && e.SocketError == SocketError.Success)
                {
                    ForwardToClient(token, e.Buffer.AsMemory(e.Offset, e.BytesTransferred));
                    DrainAvailable(token);

                    if (!token.TargetSocket.Connected)
                    {
                        break;
                    }
                    if (token.TargetSocket.ReceiveAsync(e))
                    {
                        // Pending: Target_IO_Completed resumes processing; e must not be touched here.
                        return;
                    }
                }
            }
            catch (Exception ex)
            {
                Logger.Instance.DebugError(ex);
            }
            CloseClientSocket(e);
        }

        /// <summary>
        /// Synchronously reads whatever the target socket already has buffered and forwards it,
        /// using a pooled buffer that is always returned.
        /// </summary>
        private void DrainAvailable(AsyncServerUserToken token)
        {
            Socket socket = token.TargetSocket;
            if (socket.Available <= 0)
            {
                return;
            }
            byte[] buffer = ArrayPool<byte>.Shared.Rent(socket.Available);
            try
            {
                while (socket.Available > 0)
                {
                    // Receive fills the buffer from index 0, so the slice must start at 0.
                    int length = socket.Receive(buffer);
                    if (length > 0)
                    {
                        ForwardToClient(token, buffer.AsMemory(0, length));
                    }
                }
            }
            finally
            {
                ArrayPool<byte>.Shared.Return(buffer);
            }
        }

        /// <summary>
        /// Sends one chunk of target data to the client and blocks until the send task finishes,
        /// because callers reuse <paramref name="data"/>'s buffer immediately afterwards.
        /// </summary>
        private void ForwardToClient(AsyncServerUserToken token, Memory<byte> data)
        {
            socks5MessengerSender.Response(new Socks5Info { Id = token.Key.RequestId, Data = data }, token.Connection).Wait();
        }

        /// <summary>
        /// Ends a relay from the receive side: closes the target socket, unregisters the token,
        /// disposes the event args and sends ResponseClose for the request.
        /// </summary>
        private void CloseClientSocket(SocketAsyncEventArgs e)
        {
            AsyncServerUserToken token = (AsyncServerUserToken)e.UserToken;
            IConnection connection = token.Connection;
            CloseClientSocket(token);
            e.Dispose();

            _ = socks5MessengerSender.ResponseClose(token.Key.RequestId, connection);
        }

        /// <summary>Closes the target socket and removes the token from <see cref="connections"/>.</summary>
        private void CloseClientSocket(AsyncServerUserToken token)
        {
            token.Clear();
            connections.TryRemove(token.Key, out _);
        }

    }

    /// <summary>
    /// State attached to one relayed request: its lookup key, the client connection that target
    /// data is sent back to, and the socket connected to the destination.
    /// </summary>
    public class AsyncServerUserToken
    {
        /// <summary>Key under which this token is registered (client connection id + request id).</summary>
        public ConnectionKey Key { get; set; }

        /// <summary>Client connection that receives data read from <see cref="TargetSocket"/>.</summary>
        public IConnection Connection { get; set; }

        /// <summary>Socket connected to the requested destination.</summary>
        public Socket TargetSocket { get; set; }

        /// <summary>Starts at zero; not read or written by the code in this file.</summary>
        public short SyncCount { get; set; }

        /// <summary>
        /// Closes <see cref="TargetSocket"/> when one is set. Both <see cref="Connection"/> and
        /// <see cref="TargetSocket"/> keep their references afterwards.
        /// </summary>
        public void Clear() => TargetSocket?.SafeClose();
    }

    /// <summary>
    /// Equality comparer for <see cref="ConnectionKey"/>: two keys are equal when both the
    /// request id and the connection id match.
    /// </summary>
    public class ConnectionKeyComparer : IEqualityComparer<ConnectionKey>
    {
        public bool Equals(ConnectionKey x, ConnectionKey y)
        {
            return x.RequestId == y.RequestId && x.ConnectId == y.ConnectId;
        }

        /// <summary>
        /// Hashes both ids, consistent with <see cref="Equals(ConnectionKey, ConnectionKey)"/>.
        /// A constant hash would put every key in one bucket and turn dictionary lookups into
        /// linear scans.
        /// </summary>
        public int GetHashCode(ConnectionKey obj)
        {
            return HashCode.Combine(obj.ConnectId, obj.RequestId);
        }
    }

    /// <summary>
    /// Identifies one relayed request: the id of the client connection plus the request id
    /// carried in the SOCKS5 messages.
    /// </summary>
    public readonly struct ConnectionKey
    {
        public readonly ulong RequestId { get; }
        public readonly ulong ConnectId { get; }

        /// <param name="connectId">Id of the client connection.</param>
        /// <param name="requestId">Id of the request on that connection.</param>
        public ConnectionKey(ulong connectId, ulong requestId)
        {
            ConnectId = connectId;
            RequestId = requestId;
        }
    }
}
